{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 循环神经网络—TensorFlow"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensorflow.examples.tutorials.mnist import input_data\n",
    "import tensorflow as tf\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 导入 MNIST 数据集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From <ipython-input-2-3f5d8b810895>:1: read_data_sets (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use alternatives such as official/mnist/dataset.py from tensorflow/models.\n",
      "WARNING:tensorflow:From C:\\ProgramData\\Anaconda3\\envs\\tensorflow\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\mnist.py:260: maybe_download (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please write your own downloading logic.\n",
      "WARNING:tensorflow:From C:\\ProgramData\\Anaconda3\\envs\\tensorflow\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\mnist.py:262: extract_images (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use tf.data to implement this functionality.\n",
      "Extracting ./datasets/ch08/tensorflow/MNIST\\train-images-idx3-ubyte.gz\n",
      "WARNING:tensorflow:From C:\\ProgramData\\Anaconda3\\envs\\tensorflow\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\mnist.py:267: extract_labels (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use tf.data to implement this functionality.\n",
      "Extracting ./datasets/ch08/tensorflow/MNIST\\train-labels-idx1-ubyte.gz\n",
      "WARNING:tensorflow:From C:\\ProgramData\\Anaconda3\\envs\\tensorflow\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\mnist.py:110: dense_to_one_hot (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use tf.one_hot on tensors.\n",
      "Extracting ./datasets/ch08/tensorflow/MNIST\\t10k-images-idx3-ubyte.gz\n",
      "Extracting ./datasets/ch08/tensorflow/MNIST\\t10k-labels-idx1-ubyte.gz\n",
      "WARNING:tensorflow:From C:\\ProgramData\\Anaconda3\\envs\\tensorflow\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\mnist.py:290: DataSet.__init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use alternatives such as official/mnist/dataset.py from tensorflow/models.\n"
     ]
    }
   ],
   "source": [
    "# Load MNIST from the local data directory with one-hot encoded labels.\n",
    "mnist = input_data.read_data_sets(\n",
    "    train_dir='./datasets/ch08/tensorflow/MNIST',\n",
    "    one_hot=True\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(55000, 784)\n",
      "(55000, 10)\n"
     ]
    }
   ],
   "source": [
    "# Training split: shape of the image array, then of the label array.\n",
    "print(mnist.train.images.shape, mnist.train.labels.shape, sep='\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(5000, 784)\n",
      "(5000, 10)\n"
     ]
    }
   ],
   "source": [
    "# Validation split: shape of the image array, then of the label array.\n",
    "print(mnist.validation.images.shape, mnist.validation.labels.shape, sep='\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(10000, 784)\n",
      "(10000, 10)\n"
     ]
    }
   ],
   "source": [
    "# Test split: shape of the image array, then of the label array.\n",
    "print(mnist.test.images.shape, mnist.test.labels.shape, sep='\\n')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 定义占位符 placeholder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Start from an empty default graph so that this cell and the graph-building\n",
    "# cells below can be re-run on a live kernel without 'variable already exists'\n",
    "# errors. Tensors and ops created before this call belong to the old graph,\n",
    "# so every later graph-building cell must be re-run after this one.\n",
    "tf.reset_default_graph()\n",
    "\n",
    "# Hyperparameters\n",
    "batch_size = 100        # number of images processed per training step\n",
    "time_step = 28          # input sequence length of the LSTM: one step per image row (28 rows)\n",
    "input_size = 28         # length of the input vector at each step: one image row (28 columns)\n",
    "lr = 0.001              # learning rate\n",
    "num_units = 100         # number of LSTM units in the hidden layer\n",
    "iterations = 1000       # number of training iterations\n",
    "classes = 10            # output size: one score per digit 0-9\n",
    "\n",
    "# Placeholders\n",
    "# Flattened images, shape [batch_size, time_step * input_size]\n",
    "x = tf.placeholder(tf.float32, [None, time_step * input_size])\n",
    "# The input is 2-D; restore it to 3-D with shape [batch_size, time_step, input_size]\n",
    "x_image = tf.reshape(x, [-1, time_step, input_size])\n",
    "# One-hot labels, shape [batch_size, classes]\n",
    "y = tf.placeholder(tf.int32, [None, classes])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 定义RNN（LSTM）结构"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Basic LSTM cell with num_units hidden units.\n",
    "rnn_cell = tf.contrib.rnn.BasicLSTMCell(num_units=num_units)\n",
    "\n",
    "# Unroll the cell along the time axis of x_image.\n",
    "#   initial_state=None : no explicit initial state is supplied, so dtype must be given\n",
    "#   time_major=False   : inputs are laid out as (batch, time_step, input), which matches\n",
    "#                        x_image; True would mean (time_step, batch, input)\n",
    "outputs, final_state = tf.nn.dynamic_rnn(\n",
    "    rnn_cell,\n",
    "    x_image,\n",
    "    initial_state=None,\n",
    "    dtype=tf.float32,\n",
    "    time_major=False,\n",
    ")\n",
    "\n",
    "# Only the LSTM output at the last time step feeds the classification layer.\n",
    "last_step_output = outputs[:, -1, :]\n",
    "output = tf.layers.dense(inputs=last_step_output, units=classes)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 定义损失函数与优化算法"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Loss: softmax cross-entropy between the one-hot labels and the logits.\n",
    "cross_entropy = tf.losses.softmax_cross_entropy(logits=output, onehot_labels=y)\n",
    "\n",
    "# Optimization: one Adam update of the loss per run of train_step.\n",
    "optimizer = tf.train.AdamOptimizer(learning_rate=lr)\n",
    "train_step = optimizer.minimize(cross_entropy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A prediction is correct when the arg-max of the logits equals the arg-max\n",
    "# of the one-hot label.\n",
    "true_class = tf.argmax(y, axis=1)\n",
    "predicted_class = tf.argmax(output, axis=1)\n",
    "correct_prediction = tf.equal(true_class, predicted_class)\n",
    "\n",
    "# Accuracy is the mean of the boolean hits cast to floats.\n",
    "accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 训练并验证准确率"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train accuracy 0.560\n",
      "train accuracy 0.740\n",
      "train accuracy 0.900\n",
      "train accuracy 0.810\n",
      "train accuracy 0.940\n",
      "train accuracy 0.890\n",
      "train accuracy 0.880\n",
      "train accuracy 0.910\n",
      "train accuracy 0.910\n",
      "train accuracy 0.930\n",
      "train accuracy 0.870\n",
      "train accuracy 0.960\n",
      "train accuracy 0.960\n",
      "train accuracy 0.940\n",
      "train accuracy 0.930\n",
      "train accuracy 0.960\n",
      "train accuracy 0.940\n",
      "train accuracy 0.980\n",
      "train accuracy 0.920\n",
      "train accuracy 0.980\n",
      "test accuracy 0.958\n"
     ]
    }
   ],
   "source": [
    "# NOTE(review): this session is never closed in the notebook; call sess.close()\n",
    "# once the trained model is no longer needed.\n",
    "sess = tf.Session()\n",
    "init = tf.global_variables_initializer()\n",
    "sess.run(init)\n",
    "\n",
    "report_every = 50  # report accuracy on the current batch every 50 iterations\n",
    "\n",
    "for i in range(iterations):\n",
    "    batch_x, batch_y = mnist.train.next_batch(batch_size)\n",
    "    batch_feed = {x: batch_x, y: batch_y}\n",
    "    sess.run(train_step, feed_dict=batch_feed)\n",
    "    if (i + 1) % report_every != 0:\n",
    "        continue\n",
    "    # Accuracy on the batch that was just used for the update step.\n",
    "    train_acc = sess.run(accuracy, feed_dict=batch_feed)\n",
    "    print(\"train accuracy %.3f\" % train_acc)\n",
    "\n",
    "# Final evaluation on the whole test split in a single run.\n",
    "test_feed = {x: mnist.test.images, y: mnist.test.labels}\n",
    "test_acc = sess.run(accuracy, feed_dict=test_feed)\n",
    "print(\"test accuracy %.3f\" % test_acc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
